import math 
import torch
from torch.autograd import Variable
import numpy as np


def L( r ):
    """Radial low-pass profile.

    Equals 2 for ``r <= pi/4`` and 0 for ``r >= pi/2``. In between it decreases
    from 2 to 0 as ``2 * cos(pi/2 * log2(4 * r / pi))``.

    Args:
        r: radius, e.g. one entry of the radius grid from ``polar_map``.

    Returns:
        The profile value at ``r`` (the int 2 or 0 outside the transition band).
    """
    quarter_pi = math.pi / 4
    half_pi = math.pi / 2
    if r <= quarter_pi:
        return 2
    if r >= half_pi:
        return 0
    ratio = 4 * r / math.pi
    # log(ratio) / log(2) is log2(ratio), which runs from 0 to 1 across the band.
    angle = math.pi / 2 * math.log( ratio ) / math.log( 2 )
    return 2 * math.cos( angle )

def H( r ):
    """Radial high-pass profile.

    Equals 0 for ``r <= pi/4`` and 1 for ``r >= pi/2``. In between it rises as
    ``cos(pi/2 * log2(2 * r / pi))``. Together with ``L`` in this file,
    ``(L(r) / 2)**2 + H(r)**2 == 1`` for every ``r`` (up to rounding).

    Args:
        r: radius, e.g. one entry of the radius grid from ``polar_map``.

    Returns:
        The profile value at ``r`` (the int 0 or 1 outside the transition band).
    """
    if r <= math.pi / 4:
        return 0
    if r >= math.pi / 2:
        return 1
    ratio = 2 * r / math.pi
    # log2(ratio) runs from -1 to 0 across the band, so the cosine goes 0 -> 1.
    angle = math.pi / 2 * math.log( ratio ) / math.log( 2 )
    return math.cos( angle )

def G( t, k, K ):
    """Angular profile for orientation ``k`` out of ``K``.

    Computes ``aK * cos(t - t0)**(K - 1)`` with ``t0 = pi * k / K``. Before
    evaluating, ``t`` is moved by whole multiples of ``pi`` (one per recursive
    call) until ``-pi/2 <= t - t0 <= pi/2``. The normalisation constant is
    ``aK = 2**(K-1) * (K-1)! / sqrt(K * (2*(K-1))!)``.

    Args:
        t: angle (a Python float or a 0-d tensor).
        k: orientation index.
        K: number of orientations.

    Returns:
        The profile value as a Python float.
    """
    center = math.pi * k / K
    scale = 2**(K-1) * math.factorial(K-1) / math.sqrt( K * math.factorial( 2 * (K-1) ) )
    offset = t - center
    if offset > math.pi / 2:
        return G( t - math.pi, k, K )
    if offset < -math.pi / 2:
        return G( t + math.pi, k, K )
    return scale * math.cos( offset )**(K-1)

def S( t, k, K ):
    """Three-valued indicator of the angular distance from orientation ``k`` of ``K``.

    With ``t0 = pi * k / K`` and ``d = |t - t0|``, the result is:

    * 1 when ``d < pi/2``;
    * 0 when ``d == pi/2`` exactly;
    * -1 otherwise, which includes NaN.

    Args:
        t: angle (a Python float or a 0-d tensor).
        k: orientation index.
        K: number of orientations.
    """
    half_pi = math.pi / 2
    distance = abs( t - math.pi * k / K )
    if distance == half_pi:
        return 0
    return 1 if distance < half_pi else -1

def L0( r ):
    """Stretch the low-pass profile ``L`` to twice the radius and halve it.

    Equals 1 for ``r <= pi/2`` and 0 for ``r >= pi``.
    """
    half_r = r / 2
    return L( half_r ) / 2

def H0( r ):
    """Stretch the high-pass profile ``H`` to twice the radius.

    Equals 0 for ``r <= pi/2`` and 1 for ``r >= pi``.
    """
    half_r = r / 2
    return H( half_r )

def polar_map( s ):
    """Polar coordinates of a ``(rows, cols)`` grid.

    Columns sample ``x`` uniformly on ``[0, pi]`` with ``s[1]`` points. Rows
    sample ``y`` in decreasing order, ending at ``-pi``:

    * for odd ``s[0]``, the ``s[0]`` points of ``linspace(-pi, pi)`` are used,
      so the top row is ``+pi``;
    * for even ``s[0]``, the first of ``s[0] + 1`` points is dropped, which
      puts ``y == 0`` on row ``s[0] // 2 - 1``.

    Args:
        s: grid size ``(rows, cols)``.

    Returns:
        ``(r, t)``: tensors of shape ``s`` holding the radius
        ``sqrt(x**2 + y**2)`` and the angle ``atan2(y, x)``.
    """
    rows, cols = s[0], s[1]
    x = torch.linspace( 0, math.pi, cols ).view( 1, cols ).expand( s )

    if s[0] % 2:
        ys = torch.linspace( -math.pi, math.pi, rows )
    else:
        # One extra sample, then skip the leading -pi so exactly `rows` remain.
        ys = torch.linspace( -math.pi, math.pi, rows + 1 ).narrow( 0, 1, rows )
    # Negate so that y decreases from the first row to the last.
    y = -ys.view( rows, 1 ).expand( s )

    radius = ( x ** 2 + y ** 2 ).sqrt()
    angle = torch.atan2( y, x )
    return radius, angle

def S_matrix( K, s ):
    """Evaluate ``S`` for every orientation over the angle grid of ``polar_map(s)``.

    Args:
        K: number of orientations.
        s: grid size ``(rows, cols)``.

    Returns:
        A ``(K, s[0], s[1])`` tensor of the default tensor type, with
        ``out[k, i, j] == S(t[i, j], k, K)`` where ``t`` is the angle grid.
    """
    angles = polar_map( s )[1]
    out = torch.Tensor( K, s[0], s[1] )
    for k in range( K ):
        plane = out[k]  # view into `out`; writes below fill `out` in place
        for i in range( s[0] ):
            for j in range( s[1] ):
                plane[i, j] = S( angles[i, j], k, K )
    return out

def G_matrix( K, s ):
    """Evaluate ``G`` for every orientation over the angle grid of ``polar_map(s)``.

    Args:
        K: number of orientations.
        s: grid size ``(rows, cols)``.

    Returns:
        A ``(K, s[0], s[1])`` tensor of the default tensor type, with
        ``out[k, i, j] == G(t[i, j], k, K)`` where ``t`` is the angle grid.
    """
    angles = polar_map( s )[1]
    out = torch.Tensor( K, s[0], s[1] )
    for k in range( K ):
        plane = out[k]  # view into `out`; writes below fill `out` in place
        for i in range( s[0] ):
            for j in range( s[1] ):
                plane[i, j] = G( angles[i, j], k, K )
    return out

def B_matrix( K, s ):
    """Build the oriented band masks ``H(r) * G(t, k, K)`` for ``k`` in ``range(K)``.

    The radial factor ``H``, evaluated on the radius grid of ``polar_map(s)``,
    is broadcast across all ``K`` angular planes from ``G_matrix``.

    Returns:
        A ``(K, s[0], s[1])`` tensor.
    """
    angular = G_matrix( K, s )
    radius = polar_map( s )[0]
    # apply_ evaluates H element-wise in place (CPU tensors only).
    radial = radius.apply_( H )[None]
    return radial * angular

def L_matrix( s ):
    """Evaluate the low-pass profile ``L`` on the radius grid of ``polar_map(s)``.

    Returns:
        A tensor of shape ``s``.
    """
    radius = polar_map( s )[0]
    radius.apply_( L )  # in place, element-wise (CPU tensors only)
    return radius

def LB_matrix( K, s ):
    """Stack the low-pass plane on top of the ``K`` oriented band planes.

    Returns:
        A ``(K + 1, s[0], s[1])`` tensor. Plane 0 is ``L_matrix(s)``; planes
        ``1..K`` are ``B_matrix(K, s)``.
    """
    low = L_matrix( s )[None]
    bands = B_matrix( K, s )
    return torch.cat( [low, bands], dim=0 )

def HL0_matrix( s ):
    """Evaluate ``H0`` and ``L0`` on the radius grid of ``polar_map(s)``.

    Returns:
        A ``(2, s[0], s[1])`` tensor. Plane 0 is ``H0(r)``; plane 1 is ``L0(r)``.
    """
    radius = polar_map( s )[0]
    # apply_ works in place, so the high-pass plane gets its own copy of the
    # radii; the original grid is then consumed by the low-pass plane.
    high = radius.clone().apply_( H0 )
    low = radius.apply_( L0 )
    return torch.stack( ( high, low ), dim=0 )

def central_crop( x ):
    """Crop the last two dims of ``x`` to ``(x.size(-2)//2, x.size(-1)//2 + 1)``.

    Columns ``0 .. x.size(-1)//2`` are kept. Rows are taken starting at
    ``x.size(-1)//2``.

    NOTE(review): the row offset is derived from the *width*, not the height.
    For an ``(N, N//2 + 1)`` input with ``N`` divisible by 4 it equals
    ``N//4``. Confirm this is intended for other shapes.

    Returns a view (via ``narrow``), so no data is copied. Raises if the
    requested window does not fit.
    """
    height, width = x.size( -2 ), x.size( -1 )
    keep_rows = height // 2
    keep_cols = width // 2 + 1
    row_start = keep_cols - 1
    rows = x.narrow( -2, row_start, keep_rows )
    return rows.narrow( -1, 0, keep_cols )

def cropped_size( s ):
    """Return ``[rows // 2, cols // 2 + 1]`` for a grid of size ``s = (rows, cols)``.

    This is the shape ``L_matrix_cropped`` crops its grid to.
    """
    rows, cols = s[0], s[1]
    return [rows // 2, cols // 2 + 1]

def L_matrix_cropped( s ):
    """Crop ``L_matrix(s)`` to ``cropped_size(s)``.

    With ``[rows, cols] = cropped_size(s)``, keeps columns ``0 .. cols - 1``
    and ``rows`` rows starting at row ``cols - 1``. This is the same placement
    ``central_crop`` uses for a 2-D input of size ``s``.
    """
    full = L_matrix( s )
    rows, cols = cropped_size( s )
    return full.narrow( 0, cols - 1, rows ).narrow( 1, 0, cols )

def freq_shift( imgSize, fwd, device ):
    """Return a cyclic-rotation index of length ``imgSize``.

    Entry ``i`` is ``(i - shift) % imgSize`` when ``fwd`` is truthy and
    ``(i + shift) % imgSize`` otherwise, with ``shift = (imgSize - 1) // 2``.
    Indexing an axis of length ``imgSize`` with this permutation rotates that
    axis cyclically.

    NOTE(review): for even lengths the shift is ``imgSize // 2 - 1``, one less
    than ``numpy.fft.fftshift``. This looks deliberate, since ``polar_map``
    puts ``y == 0`` on row ``n // 2 - 1`` for even ``n``; confirm before
    changing it.

    Args:
        imgSize: number of indices to produce.
        fwd: selects the direction of the rotation (see above).
        device: torch device the index tensor is created on.

    Returns:
        A 1-D ``torch.long`` tensor of length ``imgSize`` on ``device``.
    """
    shift = (imgSize - 1) // 2
    if fwd:
        shift = -shift
    # Build the indices on the host with Python's non-negative modulo and create
    # the tensor on `device` in one step. Writing element by element into a
    # tensor that already lives on `device` costs one tiny device write per
    # entry. The old `Variable(...)` wrapper is a no-op since PyTorch 0.4, which
    # `.to(device)` already required, so the plain tensor is returned.
    ind = [(i + shift) % imgSize for i in range( imgSize )]
    return torch.tensor( ind, dtype=torch.long, device=device )


##########
def sp5_filters():
    """Return the tabulated 'sp5' filter set as a dict of NumPy arrays.

    Keys and shapes:

    * ``'harmonics'``: ``[1, 3, 5]``.
    * ``'mtx'``: 6x6 matrix.
    * ``'hi0filt'``: 9x9 kernel.
    * ``'lo0filt'``: 5x5 kernel.
    * ``'lofilt'``: 9x9 kernel, equal to the tabulated coefficients times 2.
    * ``'bfilts'``: shape (49, 6). Six rows of 49 coefficients are transposed
      so that each filter is a column.

    A new dict and new arrays are built on every call.

    NOTE(review): these coefficients appear to be the 'sp5' steerable-pyramid
    filters distributed with matlabPyrTools / pyrtools, where each 49-tap
    column of ``bfilts`` is presumably a flattened 7x7 band filter. Confirm
    against upstream before editing any value.
    """
    filters = {}
    filters['harmonics'] = np.array([1, 3, 5])
    # 6x6 matrix (presumably relates the six band filters to the harmonics -- unverified here).
    filters['mtx'] = (
        np.array([[0.3333, 0.2887, 0.1667, 0.0000, -0.1667, -0.2887],
                  [0.0000, 0.1667, 0.2887, 0.3333, 0.2887, 0.1667],
                  [0.3333, -0.0000, -0.3333, -0.0000, 0.3333, -0.0000],
                  [0.0000, 0.3333, 0.0000, -0.3333, 0.0000, 0.3333],
                  [0.3333, -0.2887, 0.1667, -0.0000, -0.1667, 0.2887],
                  [-0.0000, 0.1667, -0.2887, 0.3333, -0.2887, 0.1667]]))
    # 9x9 kernel.
    filters['hi0filt'] = (
        np.array([[-0.00033429, -0.00113093, -0.00171484,
                   -0.00133542, -0.00080639, -0.00133542,
                   -0.00171484, -0.00113093, -0.00033429],
                  [-0.00113093, -0.00350017, -0.00243812,
                   0.00631653, 0.01261227, 0.00631653,
                   -0.00243812, -0.00350017, -0.00113093],
                  [-0.00171484, -0.00243812, -0.00290081,
                   -0.00673482, -0.00981051, -0.00673482,
                   -0.00290081, -0.00243812, -0.00171484],
                  [-0.00133542, 0.00631653, -0.00673482,
                   -0.07027679, -0.11435863, -0.07027679,
                   -0.00673482, 0.00631653, -0.00133542],
                  [-0.00080639, 0.01261227, -0.00981051,
                   -0.11435863, 0.81380200, -0.11435863,
                   -0.00981051, 0.01261227, -0.00080639],
                  [-0.00133542, 0.00631653, -0.00673482,
                   -0.07027679, -0.11435863, -0.07027679,
                   -0.00673482, 0.00631653, -0.00133542],
                  [-0.00171484, -0.00243812, -0.00290081,
                   -0.00673482, -0.00981051, -0.00673482,
                   -0.00290081, -0.00243812, -0.00171484],
                  [-0.00113093, -0.00350017, -0.00243812,
                   0.00631653, 0.01261227, 0.00631653,
                   -0.00243812, -0.00350017, -0.00113093],
                  [-0.00033429, -0.00113093, -0.00171484,
                   -0.00133542, -0.00080639, -0.00133542,
                   -0.00171484, -0.00113093, -0.00033429]]))
    # 5x5 kernel.
    filters['lo0filt'] = (
        np.array([[0.00341614, -0.01551246, -0.03848215, -0.01551246,
                  0.00341614],
                 [-0.01551246, 0.05586982, 0.15925570, 0.05586982,
                  -0.01551246],
                 [-0.03848215, 0.15925570, 0.40304148, 0.15925570,
                  -0.03848215],
                 [-0.01551246, 0.05586982, 0.15925570, 0.05586982,
                  -0.01551246],
                 [0.00341614, -0.01551246, -0.03848215, -0.01551246,
                  0.00341614]]))
    # 9x9 kernel; the tabulated coefficients are multiplied by 2.
    filters['lofilt'] = (
        2 * np.array([[0.00085404, -0.00244917, -0.00387812, -0.00944432,
                       -0.00962054, -0.00944432, -0.00387812, -0.00244917,
                       0.00085404],
                      [-0.00244917, -0.00523281, -0.00661117, 0.00410600,
                       0.01002988, 0.00410600, -0.00661117, -0.00523281,
                       -0.00244917],
                      [-0.00387812, -0.00661117, 0.01396746, 0.03277038,
                       0.03981393, 0.03277038, 0.01396746, -0.00661117,
                       -0.00387812],
                      [-0.00944432, 0.00410600, 0.03277038, 0.06426333,
                       0.08169618, 0.06426333, 0.03277038, 0.00410600,
                       -0.00944432],
                      [-0.00962054, 0.01002988, 0.03981393, 0.08169618,
                       0.10096540, 0.08169618, 0.03981393, 0.01002988,
                       -0.00962054],
                      [-0.00944432, 0.00410600, 0.03277038, 0.06426333,
                       0.08169618, 0.06426333, 0.03277038, 0.00410600,
                       -0.00944432],
                      [-0.00387812, -0.00661117, 0.01396746, 0.03277038,
                       0.03981393, 0.03277038, 0.01396746, -0.00661117,
                       -0.00387812],
                      [-0.00244917, -0.00523281, -0.00661117, 0.00410600,
                       0.01002988, 0.00410600, -0.00661117, -0.00523281,
                       -0.00244917],
                      [0.00085404, -0.00244917, -0.00387812, -0.00944432,
                       -0.00962054, -0.00944432, -0.00387812, -0.00244917,
                       0.00085404]]))
    # Six filters of 49 coefficients each, written one per row; the trailing
    # .T turns the (6, 49) literal into shape (49, 6), one filter per column.
    filters['bfilts'] = (
        np.array([[0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699,
                   0.00496194, 0.00277643, -0.00986904, -0.00893064,
                   0.01189859, 0.02755155, 0.01189859, -0.00893064,
                   -0.00986904, -0.01021852, -0.03075356, -0.08226445,
                   -0.11732297, -0.08226445, -0.03075356, -0.01021852,
                   0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                   0.00000000, 0.00000000, 0.01021852, 0.03075356, 0.08226445,
                   0.11732297, 0.08226445, 0.03075356, 0.01021852, 0.00986904,
                   0.00893064, -0.01189859, -0.02755155, -0.01189859,
                   0.00893064, 0.00986904, -0.00277643, -0.00496194,
                   -0.01026699, -0.01455399, -0.01026699, -0.00496194,
                   -0.00277643],
                  [-0.00343249, -0.00640815, -0.00073141, 0.01124321,
                   0.00182078, 0.00285723, 0.01166982, -0.00358461,
                   -0.01977507, -0.04084211, -0.00228219, 0.03930573,
                   0.01161195, 0.00128000, 0.01047717, 0.01486305,
                   -0.04819057, -0.12227230, -0.05394139, 0.00853965,
                   -0.00459034, 0.00790407, 0.04435647, 0.09454202,
                   -0.00000000, -0.09454202, -0.04435647, -0.00790407,
                   0.00459034, -0.00853965, 0.05394139, 0.12227230,
                   0.04819057, -0.01486305, -0.01047717, -0.00128000,
                   -0.01161195, -0.03930573, 0.00228219, 0.04084211,
                   0.01977507, 0.00358461, -0.01166982, -0.00285723,
                   -0.00182078, -0.01124321, 0.00073141, 0.00640815,
                   0.00343249],
                  [0.00343249, 0.00358461, -0.01047717, -0.00790407,
                   -0.00459034, 0.00128000, 0.01166982, 0.00640815,
                   0.01977507, -0.01486305, -0.04435647, 0.00853965,
                   0.01161195, 0.00285723, 0.00073141, 0.04084211, 0.04819057,
                   -0.09454202, -0.05394139, 0.03930573, 0.00182078,
                   -0.01124321, 0.00228219, 0.12227230, -0.00000000,
                   -0.12227230, -0.00228219, 0.01124321, -0.00182078,
                   -0.03930573, 0.05394139, 0.09454202, -0.04819057,
                   -0.04084211, -0.00073141, -0.00285723, -0.01161195,
                   -0.00853965, 0.04435647, 0.01486305, -0.01977507,
                   -0.00640815, -0.01166982, -0.00128000, 0.00459034,
                   0.00790407, 0.01047717, -0.00358461, -0.00343249],
                  [-0.00277643, 0.00986904, 0.01021852, -0.00000000,
                   -0.01021852, -0.00986904, 0.00277643, -0.00496194,
                   0.00893064, 0.03075356, -0.00000000, -0.03075356,
                   -0.00893064, 0.00496194, -0.01026699, -0.01189859,
                   0.08226445, -0.00000000, -0.08226445, 0.01189859,
                   0.01026699, -0.01455399, -0.02755155, 0.11732297,
                   -0.00000000, -0.11732297, 0.02755155, 0.01455399,
                   -0.01026699, -0.01189859, 0.08226445, -0.00000000,
                   -0.08226445, 0.01189859, 0.01026699, -0.00496194,
                   0.00893064, 0.03075356, -0.00000000, -0.03075356,
                   -0.00893064, 0.00496194, -0.00277643, 0.00986904,
                   0.01021852, -0.00000000, -0.01021852, -0.00986904,
                   0.00277643],
                  [-0.01166982, -0.00128000, 0.00459034, 0.00790407,
                   0.01047717, -0.00358461, -0.00343249, -0.00285723,
                   -0.01161195, -0.00853965, 0.04435647, 0.01486305,
                   -0.01977507, -0.00640815, -0.00182078, -0.03930573,
                   0.05394139, 0.09454202, -0.04819057, -0.04084211,
                   -0.00073141, -0.01124321, 0.00228219, 0.12227230,
                   -0.00000000, -0.12227230, -0.00228219, 0.01124321,
                   0.00073141, 0.04084211, 0.04819057, -0.09454202,
                   -0.05394139, 0.03930573, 0.00182078, 0.00640815,
                   0.01977507, -0.01486305, -0.04435647, 0.00853965,
                   0.01161195, 0.00285723, 0.00343249, 0.00358461,
                   -0.01047717, -0.00790407, -0.00459034, 0.00128000,
                   0.01166982],
                  [-0.01166982, -0.00285723, -0.00182078, -0.01124321,
                   0.00073141, 0.00640815, 0.00343249, -0.00128000,
                   -0.01161195, -0.03930573, 0.00228219, 0.04084211,
                   0.01977507, 0.00358461, 0.00459034, -0.00853965,
                   0.05394139, 0.12227230, 0.04819057, -0.01486305,
                   -0.01047717, 0.00790407, 0.04435647, 0.09454202,
                   -0.00000000, -0.09454202, -0.04435647, -0.00790407,
                   0.01047717, 0.01486305, -0.04819057, -0.12227230,
                   -0.05394139, 0.00853965, -0.00459034, -0.00358461,
                   -0.01977507, -0.04084211, -0.00228219, 0.03930573,
                   0.01161195, 0.00128000, -0.00343249, -0.00640815,
                   -0.00073141, 0.01124321, 0.00182078, 0.00285723,
                   0.01166982]]).T)
    return filters